{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "accelerator": "TPU",
    "colab": {
      "name": "DC-GAN.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.10"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "tRqlMiFgQhxE"
      },
      "source": [
        "<h3>  &nbsp;&nbsp;Training DC-GAN using Colab Cloud TPU&nbsp;&nbsp; <a href=\"https://cloud.google.com/tpu/\"><img valign=\"middle\" src=\"https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png\" width=\"50\"></a></h3>\n",
        "\n",
        "* On the main menu, click Runtime and select **Change runtime type**. Set \"TPU\" as the hardware accelerator.\n",
        "* The cell below makes sure you have access to a TPU on Colab."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "nxtjkPBWQhxF",
        "colab": {}
      },
      "source": [
        "import os\n",
        "assert os.environ['COLAB_TPU_ADDR'], 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "avWYbx7peHGq"
      },
      "source": [
        "### [RUNME] Install Colab TPU compatible PyTorch/TPU wheels and dependencies\n",
        "This may take up to ~2 minutes"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "vJZrkoejQhxK",
        "colab": {}
      },
      "source": [
        "VERSION = \"nightly\"  #@param [\"1.5\" , \"20200325\", \"nightly\"]\n",
        "!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py\n",
        "!python pytorch-xla-env-setup.py --version $VERSION"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "p8Ji_quhQhxO"
      },
      "source": [
        "## Generative Adersarial Networks (GANs)\n",
        "\n",
        "In the landmark paper [Goodfellow et al.](https://arxiv.org/abs/1406.2661) , published in 2014, authors introduced this novel paradigm for generative models. The fundamental idea proposed in the work is to train a Generator Network in adversarial setup, where a discriminator network downstream critiques the generated samples.\n",
        "\n",
        "Simply put, generator network generates a sample and discriminator network classifies it as a real or fake. Discriminator is also provided with real samples. The objective functions takes the following form:\n",
        "\n",
        "$$\\underset{G}{\\text{minimize}}\\; \\underset{D}{\\text{maximize}}\\; \\mathbb{E}_{x \\sim p_\\text{data}}\\left[\\log D(x)\\right] + \\mathbb{E}_{z \\sim p(z)}\\left[\\log \\left(1-D(G(z))\\right)\\right]$$\n",
        "where: <br>\n",
        "$x \\sim p_\\text{data}$ are samples from the input data.\n",
        "$z \\sim p(z)$ are the random noise samples.\n",
        "$G(z)$ are the generated images using the neural network generator $G$, and $D$ is the output of the discriminator, specifying the probability of an input being real."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "kUOhUi2HQhxO"
      },
      "source": [
        "## Training Setup\n",
        "\n",
        "This example illustrates distributed (data parallel) training of DC-GAN model using MNIST dataset on a TPU device. A TPU device consistes of 4 chips (8 cores; 2 cores/chip). Both the discriminator and generator replica are created on each of 8 cores. The dataset is splitted across the 8 cores. \n",
        "<br> At every training step, each of the cores perfoms the forward (loss computation) and backward (gradient computation) on the given minibatch and then [all_reduce](https://www.tensorflow.org/xla/operation_semantics#allreduce) is performed across TPU cores to update the parameters. Notice `xm.optimizer_step` call in the discriminator and optimizer train steps.\n",
        "\n",
        "General GAN training looks like:\n",
        "\n",
        "1. update the **generator** ($G$) to minimize the probability of the __discriminator making the correct choice__. \n",
        "2. update the **discriminator** ($D$) to maximize the probability of the __discriminator making the correct choice__.\n",
        "\n",
        "We will use a different objective when we update the generator: maximize the probability of the **discriminator making the incorrect choice**. This small change helps to alleviate problems with the generator gradient vanishing when the discriminator is confident. This is the standard update used in most GAN papers, and was used in the original paper from [Goodfellow et al.](https://arxiv.org/abs/1406.2661). \n",
        "\n",
        "Therefore the training loop in this notebook will entail:\n",
        "1. Update the generator ($G$) to maximize the probability of the discriminator making the incorrect choice on generated data:\n",
        "$$\\underset{G}{\\text{maximize}}\\;  \\mathbb{E}_{z \\sim p(z)}\\left[\\log D(G(z))\\right]$$\n",
        "2. Update the discriminator ($D$), to maximize the probability of the discriminator making the correct choice on real and generated data:\n",
        "$$\\underset{D}{\\text{maximize}}\\; \\mathbb{E}_{x \\sim p_\\text{data}}\\left[\\log D(x)\\right] + \\mathbb{E}_{z \\sim p(z)}\\left[\\log \\left(1-D(G(z))\\right)\\right]$$"
      ]
    },
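    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "A short worked derivation may make the vanishing-gradient remark above concrete. It is only a sketch, assuming (as in the `DiscriminativeNet` defined later in this notebook) that the discriminator ends in a sigmoid, $D(G(z)) = \\sigma(a)$. The symbol $a$ for the pre-sigmoid output is introduced here for illustration.\n",
        "\n",
        "$$\\frac{\\partial}{\\partial a}\\log\\left(1-\\sigma(a)\\right) = -\\sigma(a), \\qquad \\frac{\\partial}{\\partial a}\\log \\sigma(a) = 1-\\sigma(a)$$\n",
        "\n",
        "When the discriminator confidently rejects a generated sample, $\\sigma(a) \\to 0$: the gradient of the original generator term $\\log\\left(1-D(G(z))\\right)$ goes to $0$, while the gradient of the alternative term $\\log D(G(z))$ goes to $1$, so the generator keeps receiving a learning signal. Maximizing $\\log D(G(z))$ is the same as minimizing the binary cross-entropy between $D(G(z))$ and the target label $1$."
      ]
    },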
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "o6r3OlA_QhxP",
        "colab": {}
      },
      "source": [
        "import torch\n",
        "from torch import nn, optim\n",
        "from torchvision import transforms, datasets\n",
        "from torch.optim import Adam\n",
        "import torch.nn.functional as F\n",
        "\n",
        "import torch_xla\n",
        "import torch_xla.core.xla_model as xm\n",
        "import torch_xla.debug.metrics as met\n",
        "import torch_xla.distributed.parallel_loader as pl\n",
        "import torch_xla.distributed.xla_multiprocessing as xmp\n",
        "import torch_xla.utils.utils as xu"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "AYxAwFtDQhxR"
      },
      "source": [
        "# Setting up the Global Flags\n",
        "\n",
        "In the current setup, Discriminator network was chosen to be a smaller capacity than generator. Even with similar capacity networks, generator update path is deeper than discriminator. Therefore uneven learning rates chosen here seems to yield a better convergence. \n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "fDStxFFGQhxR",
        "colab": {}
      },
      "source": [
        "# Define Parameters\n",
        "FLAGS = {}\n",
        "FLAGS['datadir'] = \"/tmp/mnist\"\n",
        "FLAGS['batch_size'] = 128\n",
        "FLAGS['num_workers'] = 4\n",
        "FLAGS['gen_learning_rate'] = 0.005\n",
        "FLAGS['disc_learning_rate'] = 0.001\n",
        "FLAGS['num_epochs'] = 21\n",
        "FLAGS['num_cores'] = 8  "
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "R40fkk90QhxT",
        "colab": {}
      },
      "source": [
        "from matplotlib.pyplot import imshow\n",
        "from matplotlib import pyplot as plt\n",
        "from IPython import display \n",
        "\n",
        "from google.colab.patches import cv2_imshow\n",
        "import cv2\n",
        "    \n",
        "RESULT_IMG_PATH = '/tmp/test_result.png'\n",
        "\n",
        "def plot_results(*images):\n",
        "    num_images = len(images)\n",
        "    n_rows = 4\n",
        "    n_columns =len(images) // n_rows\n",
        "    fig, axes = plt.subplots(n_rows, n_columns, figsize=(11, 9))\n",
        "\n",
        "    for i, ax in enumerate(fig.axes):\n",
        "        ax.axis('off') \n",
        "        if i >= num_images:\n",
        "          continue\n",
        "        img = images[i]\n",
        "        img = img.squeeze() # [1,Y,X] -> [Y,X]\n",
        "        ax.imshow(img)\n",
        "    plt.savefig(RESULT_IMG_PATH, transparent=True)\n",
        "\n",
        "def display_results():\n",
        "    img = cv2.imread(RESULT_IMG_PATH, cv2.IMREAD_UNCHANGED)\n",
        "    cv2_imshow(img)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "QH6HiMRaQhxW",
        "colab": {}
      },
      "source": [
        "def mnist_data():\n",
        "    compose = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])\n",
        "    out_dir = '{}/dataset'.format(FLAGS['datadir'])\n",
        "    return datasets.MNIST(root=out_dir, train=True, transform=compose, download=True)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Pjtbtac9QhxY",
        "colab": {}
      },
      "source": [
        "class DiscriminativeNet(torch.nn.Module):\n",
        "    \n",
        "    def __init__(self):\n",
        "        super(DiscriminativeNet, self).__init__()\n",
        "        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)\n",
        "        self.bn1 = nn.BatchNorm2d(32)\n",
        "        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)\n",
        "        self.bn2 = nn.BatchNorm2d(64)\n",
        "        self.fc1 = nn.Linear(4*4*64, 1)\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = F.leaky_relu(F.max_pool2d(self.conv1(x), 2), 0.01)\n",
        "        x = self.bn1(x)\n",
        "        x = F.leaky_relu(F.max_pool2d(self.conv2(x), 2), 0.01)\n",
        "        x = self.bn2(x)\n",
        "        x = torch.flatten(x, 1)\n",
        "        x = F.leaky_relu(self.fc1(x), 0.01)\n",
        "        return torch.sigmoid(x)            \n",
        "        "
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "rATu6mC2Qhxa",
        "colab": {}
      },
      "source": [
        "class GenerativeNet(torch.nn.Module):\n",
        "    \n",
        "    def __init__(self):\n",
        "        super(GenerativeNet, self).__init__()\n",
        "        self.input_size = 100\n",
        "        self.linear1 = nn.Linear(self.input_size, 1024)\n",
        "        self.bn1 = nn.BatchNorm1d(1024)\n",
        "        self.linear2 = nn.Linear(1024, 7*7*128)\n",
        "        self.bn2 = nn.BatchNorm1d(7*7*128)\n",
        "        self.conv1 = nn.ConvTranspose2d(\n",
        "            in_channels=128, \n",
        "            out_channels=64, \n",
        "            kernel_size=4,\n",
        "            stride=2, \n",
        "            padding=1, \n",
        "            bias=False\n",
        "        )\n",
        "        self.bn3 = nn.BatchNorm2d(64)\n",
        "        self.conv2 = nn.ConvTranspose2d(\n",
        "            in_channels=64, \n",
        "            out_channels=1, \n",
        "            kernel_size=4,\n",
        "            stride=2, \n",
        "            padding=1, \n",
        "            bias=False\n",
        "        )\n",
        "\n",
        "    # Noise\n",
        "    def generate_noise(self, size):\n",
        "        n = torch.randn(size, self.input_size)\n",
        "        return n \n",
        "              \n",
        "    def forward(self, x):\n",
        "        x = self.linear1(x)\n",
        "        x = F.relu(x)\n",
        "        x = self.bn1(x)\n",
        "        x = self.linear2(x)\n",
        "        x = F.relu(x)\n",
        "        x = self.bn2(x)\n",
        "        x = x.view(x.shape[0], 128, 7, 7)\n",
        "        x = self.conv1(x)\n",
        "        x = F.relu(x)\n",
        "        x = self.bn3(x)\n",
        "        x = self.conv2(x)\n",
        "        x = torch.tanh(x)\n",
        "        return x"
      ],
      "execution_count": 0,
      "outputs": []
    },
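    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The layer sizes in the two networks above can be checked by hand. What follows is a worked sketch using the standard output-size formulas for convolution and transposed convolution with default dilation and no output padding. The symbols $n$, $k$, $s$ and $p$ (input size, kernel size, stride, padding) are introduced here for illustration.\n",
        "\n",
        "Discriminator (unpadded convolution with stride 1, $n_\\text{out} = n - k + 1$, each followed by a $2 \\times 2$ max pooling that halves the size):\n",
        "$$28 \\;\\to\\; 28-5+1 = 24 \\;\\to\\; 24/2 = 12 \\;\\to\\; 12-5+1 = 8 \\;\\to\\; 8/2 = 4$$\n",
        "so the flattened feature vector has $4 \\cdot 4 \\cdot 64 = 1024$ entries, matching `nn.Linear(4*4*64, 1)`.\n",
        "\n",
        "Generator (transposed convolution, $n_\\text{out} = (n-1)s - 2p + k$ with $k=4$, $s=2$, $p=1$):\n",
        "$$7 \\;\\to\\; (7-1)\\cdot 2 - 2\\cdot 1 + 4 = 14 \\;\\to\\; (14-1)\\cdot 2 - 2\\cdot 1 + 4 = 28$$\n",
        "which recovers the $28 \\times 28$ MNIST image size."
      ]
    },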
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "3oBrhFUEQhxc",
        "colab": {}
      },
      "source": [
        "def init_weights(m):\n",
        "    classname = m.__class__.__name__\n",
        "    if classname.find('Conv') != -1 or classname.find('BatchNorm') != -1:\n",
        "        m.weight.data.normal_(0.00, 0.02)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "EZ45r5UvQhxe",
        "colab": {}
      },
      "source": [
        "def real_data_target(size, device):\n",
        "    '''\n",
        "    Tensor containing ones, with shape = size\n",
        "    '''\n",
        "    data = torch.ones(size, 1)\n",
        "    return data.to(device)\n",
        "\n",
        "def fake_data_target(size, device):\n",
        "    '''\n",
        "    Tensor containing zeros, with shape = size\n",
        "    '''\n",
        "    data = torch.zeros(size, 1)\n",
        "    return data.to(device)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ynBtejcpQhxj"
      },
      "source": [
        "# Note on the use of .detach() method\n",
        "\n",
        "You will notice in the following code snippet that when the generator is used to create the `fake_data`, .detach() for the discriminator training step, the .detach call is used to create a new view of the fake_data tensor for which the operations will not be recorded for gradient computation.<br>\n",
        "\n",
        "Since fake_date is an output of an nn.module, by default, pytorch will record all the operations performed on this tensor during the forward pass as DAG. And after the backward pass these DAG and corresponding operations are cleared (unless `retain_graph=True`).\n",
        "Therefore such a tensor can be part of only one cone of logic where the forward and backward pass is done. If there are two loss function where this tensor is used and backward pass is performed on these two function (or even sum of the functions) for the second backward pass the operations DAG will not be found, leading to an error.\n",
        "\n",
        "The second place, where detach() is used is when a numpy() call is to be made to tensor (for plotting purposes). Pytorch also requires that requires_grad should not be true on these tensor. (Ref: `\n",
        "RuntimeError: Can't call numpy() on Variable that requires grad. Use var.detach().numpy() instead.`)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "OFzyeN_GQhxk",
        "colab": {}
      },
      "source": [
        "SERIAL_EXEC = xmp.MpSerialExecutor()\n",
        "# Only instantiate model weights once in memory.\n",
        "generator = GenerativeNet()\n",
        "generator.apply(init_weights)\n",
        "descriminator = DiscriminativeNet()\n",
        "descriminator.apply(init_weights)\n",
        "WRAPPED_GENERATOR = xmp.MpModelWrapper(generator)\n",
        "WRAPPED_DISCRIMINATOR = xmp.MpModelWrapper(descriminator)\n",
        "\n",
        "def train_gan(rank):\n",
        "    torch.manual_seed(1) \n",
        "    data = SERIAL_EXEC.run(lambda: mnist_data())\n",
        "    train_sampler = torch.utils.data.distributed.DistributedSampler(\n",
        "        data,\n",
        "        num_replicas=xm.xrt_world_size(),\n",
        "        rank=xm.get_ordinal(),\n",
        "        shuffle=True)\n",
        "    \n",
        "\n",
        "    # Create loader with data, so that we can iterate over it\n",
        "    train_loader = torch.utils.data.DataLoader(\n",
        "      data,\n",
        "      batch_size=FLAGS['batch_size'],\n",
        "      sampler=train_sampler,\n",
        "      num_workers=FLAGS['num_workers'],\n",
        "      drop_last=True)\n",
        "\n",
        "    # Num batches\n",
        "    num_batches = len(train_loader)\n",
        "    \n",
        "    device = xm.xla_device()\n",
        "    \n",
        "    generator = WRAPPED_GENERATOR.to(device)\n",
        "    discriminator = WRAPPED_DISCRIMINATOR.to(device)\n",
        "   \n",
        "    \n",
        "    # Optimizers\n",
        "    d_optimizer = Adam(discriminator.parameters(), lr=FLAGS['disc_learning_rate'], betas=(0.5, 0.999))\n",
        "    g_optimizer = Adam(generator.parameters(), lr=FLAGS['gen_learning_rate'], betas=(0.5, 0.999))\n",
        "\n",
        "    # Number of epochs\n",
        "    num_epochs = FLAGS['num_epochs'] \n",
        "    # Loss function\n",
        "    loss = nn.BCELoss()\n",
        "    \n",
        "\n",
        "    def train_step_discriminator(optimizer, real_data, fake_data, device):         \n",
        "        # Reset gradients\n",
        "        optimizer.zero_grad()\n",
        "\n",
        "        # 1. Train on Real Data\n",
        "        prediction_real = discriminator(real_data)\n",
        "        # Calculate error and backpropagate\n",
        "        error_real = loss(prediction_real, real_data_target(real_data.size(0), device))\n",
        "        \n",
        "\n",
        "        # 2. Train on Fake Data\n",
        "        prediction_fake = discriminator(fake_data)\n",
        "        # Calculate error and backpropagate\n",
        "\n",
        "        error_fake = loss(prediction_fake, fake_data_target(real_data.size(0), device))\n",
        "        total_error = error_real + error_fake\n",
        "        total_error.backward()\n",
        "\n",
        "        # Update weights with gradients\n",
        "        xm.optimizer_step(optimizer)\n",
        "\n",
        "        return total_error, prediction_real, prediction_fake\n",
        "\n",
        "    def train_step_generator(optimizer, fake_data, device):\n",
        "        # Reset gradients\n",
        "        optimizer.zero_grad()\n",
        "        prediction = discriminator(fake_data)\n",
        "        # Calculate error and backpropagate\n",
        "        error = loss(prediction, real_data_target(prediction.size(0), device))\n",
        "        error.backward()\n",
        "        # Update weights with gradients\n",
        "        xm.optimizer_step(optimizer)\n",
        "\n",
        "        # Return error\n",
        "        return error\n",
        "\n",
        "    # Notice the use of .detach() when fake_data is to passed into discriminator\n",
        "    def train_loop_fn(loader):\n",
        "        tracker = xm.RateTracker()\n",
        "        for n_batch, (real_batch,_) in enumerate(loader):\n",
        "            # Train Step Descriminator\n",
        "            real_data = real_batch.to(device)\n",
        "            # sample noise and generate fake data\n",
        "            noise = generator.generate_noise(real_data.size(0)).to(device)\n",
        "            fake_data = generator(noise)\n",
        "            d_error, d_pred_real, d_pred_fake = train_step_discriminator(\n",
        "                d_optimizer, real_data, fake_data.detach(), device)\n",
        "            \n",
        "            #Train Step Generator\n",
        "            noise = generator.generate_noise(real_data.size(0)).to(device)\n",
        "            fake_data = generator(noise)\n",
        "            g_error = train_step_generator(g_optimizer, fake_data, device)\n",
        "        return d_error.item(), g_error.item()\n",
        "\n",
        "\n",
        "    for epoch in range(1, FLAGS['num_epochs'] + 1):\n",
        "        d_error, g_error = train_loop_fn (pl.MpDeviceLoader(train_loader, device))\n",
        "        xm.master_print(\"Finished training epoch {}: D_error:{}, G_error: {}\".format(epoch, d_error, g_error))\n",
        "        \n",
        "    num_test_samples = 24\n",
        "    test_noise = generator.generate_noise(num_test_samples).to(device)\n",
        "    xm.do_on_ordinals(plot_results, generator(test_noise).detach(), (0,))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "svXHYwFv6Nf_",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Start training processes\n",
        "def _mp_fn(rank, flags):\n",
        "    global FLAGS\n",
        "    FLAGS = flags\n",
        "    torch.set_default_tensor_type('torch.FloatTensor')\n",
        "    train_gan(rank)\n",
        "\n",
        "xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS['num_cores'],\n",
        "          start_method='fork')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "T641nJpEdh4i",
        "colab": {}
      },
      "source": [
        "display_results()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "39wv4mw-Qhxo"
      },
      "source": [
        "## References:\n",
        "[Training GAN from Scratch] (https://github.com/diegoalejogm/gans) <br>\n",
        "[CS231n] (http://cs231n.stanford.edu/slides/2017/cs231n_2017_lecture13.pdf)"
      ]
    }
  ]
}